# Makes Figures 7 and 8, which summarize the counterfactuals.
# Note that the point estimates are "hard coded" in this script (only the
# confidence intervals are read from a file), so you need to run the other
# scripts listed below to actually replicate the results!

# See "do_counterfactuals_insample.py" for the insample calculation
# See "do_counterfactuals_future.py" for the out of sample calculation
# Use "do_counterfactuals_bootstrap.py" to do the bootstrap draws

# Jacob Moscona and Karthik Sastry
# This version: last edited on September 28, 2022

##
## Importing libraries and files
##

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter
sns.set_color_codes()
import matplotlib.transforms as mtrans
import pandas as pd
import os

# Working directory: the input spreadsheet (CI_summary.xlsx) is read relative
# to it, and the figures are written to '../Results/' relative to it.
# Set the REPLICATION_DATA_DIR environment variable to point at a local copy
# of the replication Data folder; if it is unset, the authors' original path
# is used, exactly as before.
wd = os.environ.get(
    'REPLICATION_DATA_DIR',
    '/home/karthik/Dropbox (MIT)/climate_crops/QJE_Submission/Replication/Data/',
)
os.chdir(wd)

## Helper functions for drawing the bar charts and labelling each bar with its value
def autolabel(rects, axis, clr='white', neg=True):
    """Write each bar's height, rounded to one decimal place, as a label on the bar.

    Each label is centred horizontally on its bar. Vertically it is anchored
    (in data coordinates) at one quarter of ``abs(axis.get_ylim()[1])``, then
    moved down on screen by a fixed number of points.

    Parameters
    ----------
    rects : iterable of bar patches
        The bars to label, e.g. the container returned by ``Axes.bar``.
    axis : matplotlib Axes
        Axes holding the bars. Its y-limits set the label height, and the
        labels are drawn on it with ``annotate``.
    clr : str
        Text colour of the labels.
    neg : bool
        Chooses the downward offset: 15 points when True (the default, used
        for the damage panel), 20 points when False (used for the
        mitigation panel).
    """
    # Read the limits once: every label in this call shares the same height.
    # On an inverted axis, get_ylim()[1] is the numerically smaller limit.
    y_anchor = np.abs(axis.get_ylim()[1]) / 4
    # Offset in points, measured in screen space; negative means downward.
    y_offset = -15 if neg else -20

    for rect in rects:
        axis.annotate('{:.1f}'.format(rect.get_height()),
                      xy=(rect.get_x() + rect.get_width() / 2, y_anchor),
                      xytext=(0, y_offset),
                      textcoords="offset points",
                      ha='center', va='bottom',
                      color=clr)

def make_mitigation_bars(ax, k, pos, width, colors, error=None):
    """Draw the mitigation panel: one labelled bar per model.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    k : sequence of float
        Mitigation values (plotted as % of damage), one per model.
    pos : array-like
        x positions of the bars.
    width : float
        Base bar width. These bars are drawn 1.25 times wider.
    colors : sequence
        Colour sequence. Entry 2 fills the bars.
    error : array-like or None
        Optional error-bar lengths, passed to ``Axes.bar`` as ``yerr``.
    """
    ax.set_ylabel('Mitigation\n(% of Damage)')

    bar_width = 1.25 * width
    bars = ax.bar(pos, k, bar_width,
                  color=colors[2],
                  yerr=error,
                  capsize=6)

    # Whole-number tick labels on the y axis.
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.0f'))
    # White text with the larger (20-point) downward offset.
    autolabel(bars, ax, clr='white', neg=False)

def make_damage_bars(ax, d0, d1, name, pos, width, colors, lloc, error=None, legsize=6):
    """Draw the damage panel: paired "Innovation" / "No Innovation" bars.

    The x ticks are moved to the top edge of the panel, and the y axis is
    inverted so the bars extend downward.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    d0, d1 : sequence of float
        Damage with / without innovation (plotted as % of total value), one
        value per model.
    name : sequence of str
        Not used by this function; kept so existing calls keep working. The
        tick label text is set by the caller.
    pos : numpy array
        Group centres. The two bars sit at ``pos -/+ width/2``, and ``pos``
        also sets the x tick locations.
    width : float
        Width of each individual bar.
    colors : sequence
        Colour sequence. Entries 0 and 1 fill the two series.
    lloc : str
        Legend location, passed to ``Axes.legend(loc=...)``.
    error : pair or None
        ``(yerr_innovation, yerr_no_innovation)``. ``None`` (the default)
        draws no error bars on either series. Any object supporting
        ``error[0]`` / ``error[1]`` also works, e.g. ``[None, None]``.
    legsize : int or float
        Legend font size.
    """
    # Use a None sentinel instead of a shared mutable list default.
    if error is None:
        error = (None, None)

    ax.xaxis.tick_top()
    rects1 = ax.bar(pos - width/2, d0, width, label='Innovation', color=colors[0],
                    yerr=error[0], capsize=6)
    rects2 = ax.bar(pos + width/2, d1, width, label='No Innovation', color=colors[1],
                    yerr=error[1], capsize=6)

    # Bars hang down from the top edge, where the ticks were placed above.
    ax.invert_yaxis()

    ax.set_ylabel('Damage\n(% of Total Value)')
    ax.set_xticks(pos)

    ax.legend(fontsize=legsize, loc=lloc)
    # White labels with the default (15-point) downward offset.
    autolabel(rects1, ax)
    autolabel(rects2, ax)

    # Move every x tick label 6 display units upward by appending a
    # translation to its existing transform.
    trans = mtrans.Affine2D().translate(0, 6)
    for t in ax.get_xticklabels():
        t.set_transform(t.get_transform() + trans)

def make_bar_chart(models, fname='test', damage=True, width=0.35, size=(3.5, 3.5),
                   lloc='lower left', error=None, legsize=6):
    """Draw a counterfactual bar chart and save it as ``fname + '.pdf'``.

    Parameters
    ----------
    models : sequence of 4-tuples
        ``(damage_with_innovation, damage_without_innovation,
        pct_mitigation, label)``, one tuple per bar group.
    fname : str
        Output path without extension. ``'.pdf'`` is appended.
    damage : bool
        If True, draw two stacked panels that share the x axis: mitigation
        on top, damage below. If False, draw only the mitigation panel.
    width : float
        Bar width passed to the panel helpers.
    size : tuple
        Figure size ``(width, height)`` in inches.
    lloc : str
        Legend location for the damage panel.
    error : 2-D array or None
        Error-bar lengths with one row per model. Columns 0, 1 and 2 are the
        errors for damage with innovation, damage without innovation and
        mitigation. Must support ``error[:, j]`` indexing (e.g. a NumPy
        array). ``None`` draws no error bars.
    legsize : int or float
        Legend font size for the damage panel.

    Returns
    -------
    (fig, ax)
        The figure and its axes: an array of two Axes when ``damage`` is
        True, a single Axes otherwise.

    Notes
    -----
    Bar colours come from the module-level ``colors`` sequence.
    """
    d0 = [x[0] for x in models]
    d1 = [x[1] for x in models]
    k = [x[2] for x in models]
    name = [x[3] for x in models]
    pos = np.arange(len(name))

    # Split the error matrix by column: 2 is mitigation, 0/1 are damage.
    if error is not None:
        mitigation_error = error[:, 2]
        damage_error = [error[:, 0], error[:, 1]]
    else:
        mitigation_error = None
        damage_error = [None, None]

    if damage:
        fig, ax = plt.subplots(2, 1, figsize=size, sharex=True)

        make_mitigation_bars(ax[0], k, pos, width, colors, error=mitigation_error)
        make_damage_bars(ax[1], d0, d1, name, pos, width, colors, lloc,
                         error=damage_error, legsize=legsize)

        ax[1].set_xticklabels(name, fontsize=10)

        fig.align_ylabels(ax)
    else:
        # A 1x1 grid gives a single Axes object, not an array, so it is
        # passed directly rather than indexed.
        fig, ax = plt.subplots(1, 1, figsize=size)
        make_mitigation_bars(ax, k, pos, width, colors, error=mitigation_error)
        # Fix the tick locations first so each label lines up with its bar.
        ax.set_xticks(pos)
        ax.set_xticklabels(name, fontsize=10)

    fig.tight_layout()
    fig.savefig(fname + '.pdf')

    return fig, ax

##
## Main code
##
    
# Bar colours used by the plotting helpers: entry 0 = damage with innovation,
# entry 1 = damage without innovation, entry 2 = mitigation.
colors = ('C0', 'C1', 'C3', 'k')


## Plot: Baseline Estimates, 60s  (Figure 7)
# Each model tuple is
#   (damage with innov, damage without innov, pct mitigation, name)
# with damages in % of total value and mitigation in % of damage.
models = (
    (6.70, 8.36, 19.89, 'Baseline'),        # 24.47 billion
    (9.44, 11.65, 18.98, 'Area Weighted'),  # 33.05 billion
    (6.55, 8.13, 19.43, 'Price Control'),   # 23.03 billion
)

# Error bars are 1.96 times the values in columns 3 onward of the summary
# sheet. Those columns line up, in order, with error[:, 0..2] inside
# make_bar_chart. The first three rows belong to the models above.
# NOTE(review): presumably these are standard errors; confirm against
# the spreadsheet.
seFile = pd.read_excel('CI_summary.xlsx')
se = seFile.iloc[0:3, 3:].values
make_bar_chart(
    models,
    fname='../Results/Figure_7',
    size=(5, 3.5),
    error=1.96 * se,
)


## Plot: projections for RCP 4.5 and RCP 6.0  (Figure 8)
models = (
    (10.71, 12.62, 15.15, 'RCP 4.5\n2050-2059'),
    (18.88, 21.71, 13.04, 'RCP 4.5\n2090-2099'),
    (7.44, 8.84, 15.83, 'RCP 6.0\n2050-2059'),
    (21.63, 25.26, 14.39, 'RCP 6.0\n2090-2099'),
)

# Every remaining row of the sheet. It is expected to hold exactly these four
# scenarios, in the same order.
se = seFile.iloc[3:, 3:].values

make_bar_chart(
    models,
    fname='../Results/Figure_8',
    size=(6, 3.5),
    error=1.96 * se,
)
    


